# -*- coding: utf-8 -*-
"""
Created on Sat Mar 21 16:31:53 2026

@author: 49489
"""

import numpy as np
from collections import Counter


# Three-letter drug codes. The position of a code in this array is the index
# used to look up its cluster ID in `cluster_result`.
drug = np.array(
    'AMK CEF CHL CIP CLA ERY FUS GEN LEV '
    'NAL NIT OXA RIF SPE TET TOB TRI VAN'.split()
)

# Labelled drug pairs ("edges"). Each row is [drug_code_a, drug_code_b, label].
# Every element is a string, so NumPy stores the labels as the strings
# '1' / '-1' (not integers); the purity computation at the bottom of the file
# counts them with exactly those string keys.
# NOTE(review): the label is presumably the sign of the pairwise interaction
# between the two drugs -- confirm against the data source which sign means what.
alpha = np.array([
    ['CHL','FUS','1'],
    ['CHL','TET','1'],
    ['FUS','TET','1'],
    ['FUS','LEV','-1'],
    ['CHL','CIP','-1'],
    ['CHL','NAL','-1'],
    ['LEV','TET','-1'],
    ['NAL','TET','-1'],
    ['CIP','TET','-1'],
    ['CHL','LEV','-1'],
    ['CLA','FUS','1'],
    ['ERY','FUS','1'],
    ['CLA','TET','1'],
    ['ERY','TET','1'],
    ['CHL','ERY','1'],
    ['CHL','CLA','1'],
    ['GEN','TET','1'],
    ['AMK','FUS','1'],
    ['TET','TOB','-1'],
    ['CHL','GEN','-1'],
    ['AMK','CHL','-1'],
    ['CHL','TOB','-1'],
    ['FUS','RIF','1'],
    ['FUS','NIT','-1'],
    ['CHL','NIT','-1'],
    ['NIT','TET','-1'],
    ['FUS','TRI','-1'],
    ['TET','TRI','-1'],
    ['CHL','TRI','-1'],
    ['FUS','VAN','1'],
    ['FUS','OXA','1'],
    ['TET','VAN','1'],
    ['OXA','TET','1'],
    ['CEF','FUS','1'],
    ['LEV','NAL','1'],
    ['CLA','NAL','-1'],
    ['CLA','LEV','-1'],
    ['CIP','ERY','-1'],
    ['ERY','NAL','-1'],
    ['CIP','CLA','-1'],
    ['NAL','SPE','-1'],
    ['LEV','SPE','-1'],
    ['CIP','SPE','-1'],
    ['CIP','GEN','-1'],
    ['GEN','NAL','-1'],
    ['LEV','TOB','-1'],
    ['NAL','TOB','-1'],
    ['CIP','TOB','-1'],
    ['LEV','NIT','-1'],
    ['NAL','TRI','-1'],
    ['LEV','RIF','-1'],
    ['CIP','NIT','-1'],
    ['NAL','NIT','-1'],
    ['CEF','CIP','-1'],
    ['NAL','VAN','-1'],
    ['LEV','OXA','-1'],
    ['NAL','OXA','-1'],
    ['CEF','NAL','-1'],
    ['CEF','LEV','-1'],
    ['AMK','ERY','-1'],
    ['ERY','TOB','-1'],
    ['AMK','SPE','-1'],
    ['GEN','SPE','-1'],
    ['SPE','TOB','-1'],
    ['SPE','TRI','-1'],
    ['NIT','SPE','-1'],
    ['CLA','TRI','-1'],
    ['ERY','OXA','-1'],
    ['CEF','SPE','-1'],
    ['AMK','RIF','-1'],
    ['GEN','RIF','-1'],
    ['RIF','TOB','-1'],
    ['AMK','VAN','-1'],
    ['AMK','CEF','-1'],
    ['CEF','TOB','-1'],
    ['TOB','VAN','-1'],
    ['NIT','RIF','-1'],
    ['OXA','TRI','1'],
    ['TRI','VAN','1'],
    ['CEF','RIF','-1'],
    ['CEF','NIT','-1'],
    ['NIT','OXA','-1']
])


def search(x):
    """Return the cluster ID recorded for drug code ``x``.

    Scans the module-level ``drug`` sequence for the first entry equal to
    ``x`` and returns the element of the module-level ``cluster_result`` at
    the same position.

    Returns ``None`` when no entry of ``drug`` equals ``x``.
    """
    for position, code in enumerate(drug):
        if code == x:
            return cluster_result[position]
    return None

def group_search(x):
    """Return the labels of all edges whose cluster-pair key equals ``x``.

    ``GGI`` holds one key per row of ``alpha``; for every position whose key
    matches ``x`` the third column of the corresponding ``alpha`` row is
    collected, in row order. An empty list is returned when nothing matches.
    """
    return [alpha[pos][2] for pos, key in enumerate(GGI) if key == x]


# Cluster assignment under evaluation: cluster label -> member drug codes.
# The labels are the strings "0" .. "k-1"; the code below converts between
# these string keys and integer cluster IDs.
clusters = {
    "0": ['AMK', 'CHL', 'CLA', 'ERY', 'FUS', 'GEN', 'NIT', 'SPE', 'TET', 'TOB'],
    "1": ['CEF', 'OXA'],
    "2": ['RIF'],
    "3": ['CIP', 'LEV', 'NAL'],
    "4": ['VAN'],
    "5": ['TRI']
}

# Number of clusters, derived from the mapping so it cannot drift out of sync
# with `clusters` when a cluster is added or removed.
k = len(clusters)

# cluster_result[i] is the integer cluster ID of drug[i]. search() indexes
# this list by position in `drug`, so it must hold exactly one entry per drug;
# a drug that is missing from every cluster (or listed in several) would shift
# all later entries, so that case is rejected explicitly.
cluster_result = []
for code in drug:
    owners = [j for j in range(k) if str(code) in clusters[str(j)]]
    if len(owners) != 1:
        raise ValueError(
            "drug {!r} must belong to exactly one cluster, found {}".format(
                str(code), owners
            )
        )
    cluster_result.append(owners[0])

# GGI[i] is the cluster-pair key of row i of `alpha`, written as
# "<larger cluster ID>_<smaller cluster ID>" so that the key does not depend
# on the order of the two drugs within the row.
GGI = []
for row in alpha:
    drug1_ID = search(row[0])
    drug2_ID = search(row[1])
    if drug1_ID is None or drug2_ID is None:
        # search() returns None for a code that is not in `drug`; fail with a
        # clear message instead of a TypeError on the comparison below.
        raise ValueError(
            "alpha row {} references a drug that is not listed in `drug`".format(
                [str(value) for value in row]
            )
        )
    GGI.append('{}_{}'.format(max(drug1_ID, drug2_ID), min(drug1_ID, drug2_ID)))


# Edge purity: for every pair of *different* clusters (i > j), take the edges
# whose endpoints fall in those two clusters and count how many carry the
# majority label. Edges with both endpoints in the same cluster (keys such as
# "0_0") are never visited by this loop and therefore do not contribute.
x = 0  # edges that agree with the majority label of their cluster pair
y = 0  # all edges between two different clusters

for i in range(k - 1, -1, -1):
    for j in range(i - 1, -1, -1):
        name = str(i) + '_' + str(j)
        labels = group_search(name)  # looked up once, reused for both sums
        count = Counter(labels)
        x += max(count['1'], count['-1'])
        y += len(labels)

# With no inter-cluster edge the ratio is undefined; report NaN instead of
# failing with ZeroDivisionError.
purity = x / y if y else float('nan')
print("edge purity =", purity)